{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "635d8ebb",
      "metadata": {},
      "source": [
        "# Adaptive RAG\n",
        "\n",
        "- Author: [The LangChain Open Tutorial team](https://github.com/langchainopentutorial)\n",
        "- Design:\n",
        "- Peer Review:\n",
        "- This is a part of [LangChain Open Tutorial](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial)\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/99-TEMPLATE/00-BASE-TEMPLATE-EXAMPLE.ipynb) [![Open in GitHub](https://img.shields.io/badge/Open%20in%20GitHub-181717?style=flat-square&logo=github&logoColor=white)](https://github.com/LangChain-OpenTutorial/LangChain-OpenTutorial/blob/main/99-TEMPLATE/00-BASE-TEMPLATE-EXAMPLE.ipynb)\n",
        "\n",
        "## Overview\n",
        "\n",
        "This tutorial covers the implementation of Adaptive Retrieval-Augmented Generation (Adaptive RAG).\n",
        "\n",
        "Adaptive RAG is a strategy that combines query analysis and active/self-modifying RAG to retrieve and generate information from diverse data sources.\n",
        "\n",
        "In this tutorial, we use LangGraph to implement routing between web browsing and self-modifying RAGs.\n",
        "\n",
        "![adaptive-rag](./assets/langgraph-adaptive-rag.png)\n",
        "\n",
        "**Adaptive RAG** ​​is a strategy of **RAG**, combining Query Construction and Self-Reflective RAG.\n",
        "\n",
        "[Thesis: Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity](https://arxiv.org/abs/2403.14403) performs the following routing through query analysis:\n",
        "\n",
        "- `No Retrieval`\n",
        "- `Single-shot RAG`\n",
        "- `Iterative RAG`\n",
        "\n",
        "In this tutorial, we implement an example using LangGraph to perform the following routing:\n",
        "\n",
        "- **Web Search**: Used for questions related to latest events\n",
        "- **Self-Reflective RAG**: Used for questions related to indexes\n",
        "\n",
        "### Table of Contents\n",
        "\n",
        "- [Overview](#overview)\n",
        "- [Environment Setup](#environment-setup)\n",
        "- [Create a basic PDF-based Retrieval Chain](#create-a-basic-pdf-based-retrieval-chain)\n",
        "- [Query routing and document evaluation](#query-routing-and-document-evaluation)\n",
        "- [Tools](#tools)\n",
        "- [Graph Construction](#graph-construction)\n",
        "- [Define Graph Flows](#define-graph-flows)\n",
        "- [Define Nodes](#define-nodes)\n",
        "- [Graph Construction](#graph-construction)\n",
        "- [Execute Graph](#execute-graph)\n",
        "\n",
        "### References\n",
        "\n",
        "- [LangChain: Query Construction](https://blog.langchain.dev/query-construction/)\n",
        "- [LangGraph: Self-Reflective RAG](https://blog.langchain.dev/agentic-rag-with-langgraph/)\n",
        "- [Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity](https://arxiv.org/abs/2403.14403)\n",
        "----"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c6c7aba4",
      "metadata": {},
      "source": [
        "## Environment Setup\n",
        "\n",
        "Set up the environment. You may refer to [Environment Setup](https://wikidocs.net/257836) for more details.\n",
        "\n",
        "**[Note]**\n",
        "- `langchain-opentutorial` is a package that provides a set of easy-to-use environment setup, useful functions and utilities for tutorials. \n",
        "- You can checkout the [`langchain-opentutorial`](https://github.com/LangChain-OpenTutorial/langchain-opentutorial-pypi) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "21943adb",
      "metadata": {},
      "outputs": [],
      "source": [
        "%%capture --no-stderr\n",
        "!pip install langchain-opentutorial"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "f25ec196",
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n",
            "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
            "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
          ]
        }
      ],
      "source": [
        "# Install required packages\n",
        "from langchain_opentutorial import package\n",
        "\n",
        "package.install(\n",
        "    [\n",
        "        \"langsmith\",\n",
        "        \"langchain\",\n",
        "        \"langchain_core\",\n",
        "        \"langchain-anthropic\",\n",
        "        \"langchain_community\",\n",
        "        \"langchain_text_splitters\",\n",
        "        \"langchain_openai\",\n",
        "    ],\n",
        "    verbose=False,\n",
        "    upgrade=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "7f9065ea",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Environment variables have been set successfully.\n"
          ]
        }
      ],
      "source": [
        "# Set environment variables\n",
        "from langchain_opentutorial import set_env\n",
        "\n",
        "set_env(\n",
        "    {\n",
        "        \"OPENAI_API_KEY\": \"\",\n",
        "        \"LANGCHAIN_API_KEY\": \"\",\n",
        "        \"LANGCHAIN_TRACING_V2\": \"true\",\n",
        "        \"LANGCHAIN_ENDPOINT\": \"https://api.smith.langchain.com\",\n",
        "        \"LANGCHAIN_PROJECT\": \"Adaptive-RAG\",  # Please set it the same as title\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "690a9ae0",
      "metadata": {},
      "source": [
        "You can alternatively set API keys such as `OPENAI_API_KEY` in a `.env` file and load them.\n",
        "\n",
        "**[Note]** This is not necessary if you've already set the required API keys in previous steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4f99b5b6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Load API keys from .env file\n",
        "from dotenv import load_dotenv\n",
        "\n",
        "load_dotenv(override=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "aa00c3f4",
      "metadata": {},
      "source": [
        "## Create a basic PDF-based Retrieval Chain\n",
        "\n",
        "Here, we create a Retrieval Chain based on a PDF document. This is the Retrieval Chain with the simplest structure.\n",
        "\n",
        "However, in LangGraph, Retirever and Chain are created separately. Only then can detailed processing be performed for each node.\n",
        "\n",
        "**reference**\n",
        "- As this was covered in the previous tutorial, detailed explanation will be omitted."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "69cb77da",
      "metadata": {},
      "outputs": [],
      "source": [
        "from rag.pdf import PDFRetrievalChain\n",
        "\n",
        "# Load the PDF document.\n",
        "pdf = PDFRetrievalChain([\"data/SPRI_AI_Brief_December 2023_F.pdf\"]).create_chain()\n",
        "\n",
        "# create retriever\n",
        "pdf_retriever = pdf.retriever\n",
        "\n",
        "# create chain\n",
        "pdf_chain = pdf.chain"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2b2fc536",
      "metadata": {},
      "source": [
        "## Query routing and document evaluation\n",
        "\n",
        "In this step, **query routing** and **document evaluation** are performed. This process is an important part of **Adaptive RAG**, contributing to efficient information retrieval and creation.\n",
        "\n",
        "- **Query Routing**: Analyzes user queries and routes them to appropriate information sources. This allows you to set the optimal search path for the purpose of your query.\n",
        "- **Document Evaluation**: Evaluates the quality and relevance of searched documents to improve the accuracy of the final results. \n",
        "\n",
        "This step supports the core functionality of **Adaptive RAG** ​​and aims to provide accurate and reliable information."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "1b78d33f",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import Literal\n",
        "\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from pydantic import BaseModel, Field\n",
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_teddynote.models import get_model_name, LLMs\n",
        "\n",
        "# Get latest LLM model name\n",
        "MODEL_NAME = get_model_name(LLMs.GPT4)\n",
        "\n",
        "\n",
        "# Data model that routes user queries to the most relevant data sources\n",
        "class RouteQuery(BaseModel):\n",
        "    \"\"\"Route a user query to the most relevant datasource.\"\"\"\n",
        "\n",
        "    # Literal type field for data source selection\n",
        "    datasource: Literal[\"vectorstore\", \"web_search\"] = Field(\n",
        "        ...,\n",
        "        description=\"Given a user question choose to route it to web search or a vectorstore.\",\n",
        "    )\n",
        "\n",
        "\n",
        "# Generate structured output through LLM initialization and function calls\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_router = llm.with_structured_output(RouteQuery)\n",
        "\n",
        "# Create prompt templates including system messages and user questions\n",
        "system = \"\"\"You are an expert at routing a user question to a vectorstore or web search.\n",
        "The vectorstore contains documents related to DEC 2023 AI Brief Report(SPRI) with Samsung Gause, Anthropic, etc.\n",
        "Use the vectorstore for questions on these topics. Otherwise, use web-search.\"\"\"\n",
        "\n",
        "# Create a prompt template for routing\n",
        "route_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"{question}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Create a question router by combining the prompt template and structured LLM router\n",
        "question_router = route_prompt | structured_llm_router"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c9e4d831",
      "metadata": {},
      "source": [
        "Next, we will test the query routing results and check the results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0874c14b",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Questions requiring document search\n",
        "print(\n",
        "    question_router.invoke(\n",
        "        {\"question\": \"What is the name of the generative AI created by Samsung Electronics in AI Brief?\"}\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a2d22b26",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Questions that require web search\n",
        "print(question_router.invoke({\"question\": \"Find the best dim sum restaurant in Pangyo\"}))"
      ]
    },
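    {
      "cell_type": "markdown",
      "id": "b3a91c5e",
      "metadata": {},
      "source": [
        "Because `datasource` is typed as a `Literal` of `vectorstore` and `web_search`, Pydantic rejects any other value, which keeps the router output limited to the two routes handled later in the graph.\n",
        "\n",
        "The following is a minimal sketch of that behavior, independent of the LLM call. `RouteQuerySketch` is a hypothetical stand-in for the `RouteQuery` model above and is not part of the tutorial pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7d2e4f10",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import Literal\n",
        "\n",
        "from pydantic import BaseModel, Field, ValidationError\n",
        "\n",
        "\n",
        "# Hypothetical stand-in for RouteQuery, used only for this illustration\n",
        "class RouteQuerySketch(BaseModel):\n",
        "    datasource: Literal[\"vectorstore\", \"web_search\"] = Field(\n",
        "        ..., description=\"Route to web search or a vectorstore.\"\n",
        "    )\n",
        "\n",
        "\n",
        "# A valid route is accepted\n",
        "print(RouteQuerySketch(datasource=\"vectorstore\").datasource)\n",
        "\n",
        "# Any value outside the Literal raises a validation error\n",
        "try:\n",
        "    RouteQuerySketch(datasource=\"database\")\n",
        "except ValidationError:\n",
        "    print(\"Rejected: 'database' is not an allowed datasource\")"
      ]
    },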
    {
      "cell_type": "markdown",
      "id": "5fc43b99",
      "metadata": {},
      "source": [
        "### Retrieval Grader\n",
        "\n",
        "About the search evaluator..."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "d1221d80",
      "metadata": {},
      "outputs": [],
      "source": [
        "from pydantic import BaseModel, Field\n",
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "\n",
        "\n",
        "# Define data model for document evaluation\n",
        "class GradeDocuments(BaseModel):\n",
        "    \"\"\"Binary score for relevance check on retrieved documents.\"\"\"\n",
        "\n",
        "    binary_score: str = Field(\n",
        "        description=\"Documents are relevant to the question, 'yes' or 'no'\"\n",
        "    )\n",
        "\n",
        "\n",
        "# Generate structured output through LLM initialization and function calls\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_grader = llm.with_structured_output(GradeDocuments)\n",
        "\n",
        "# Create prompt templates including system messages and user questions\n",
        "system = \"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n",
        "    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
        "    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
        "    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\"\n",
        "\n",
        "grade_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"Retrieved document: \\n\\n {document} \\n\\n User question: {question}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Create a document search result evaluator\n",
        "retrieval_grader = grade_prompt | structured_llm_grader"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "927cac10",
      "metadata": {},
      "source": [
        "Evaluate the **document search result** using the `retrieval_grader` you created."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "2fa5e0d7",
      "metadata": {},
      "outputs": [],
      "source": [
        "# User question settings\n",
        "question = \"What is the name of the generative AI created by Samsung Electronics?\"\n",
        "\n",
        "# Search related documents for your question\n",
        "docs = pdf_retriever.invoke(question)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ef397b71",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Get the contents of the searched document\n",
        "retrieved_doc = docs[1].page_content\n",
        "\n",
        "# Print evaluation results\n",
        "print(retrieval_grader.invoke({\"question\": question, \"document\": retrieved_doc}))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "dce41bfd",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Filtering code example\n",
        "filtered_docs = []\n",
        "\n",
        "\n",
        "for doc in docs:\n",
        "   # Check document evaluation results\n",
        "    result = retrieval_grader.invoke(\n",
        "        {\n",
        "            \"question\": question,\n",
        "            \"document\": doc.page_content,\n",
        "        }\n",
        "    )\n",
        "    # Filter only relevant documents\n",
        "    if result.binary_score == \"yes\":\n",
        "        filtered_docs.append(doc)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "54dce7a1",
      "metadata": {},
      "source": [
        "### Create a RAG chain to generate answers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "992ef15a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain import hub\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "from langchain_openai import ChatOpenAI\n",
        "\n",
        "# Import prompts from LangChain Hub (RAG prompts can be freely modified)\n",
        "prompt = hub.pull(\"teddynote/rag-prompt\")\n",
        "\n",
        "# Initialize LLM\n",
        "llm = ChatOpenAI(model_name=MODEL_NAME, temperature=0)\n",
        "\n",
        "\n",
        "# Document formatting function\n",
        "def format_docs(docs):\n",
        "    return \"\\n\\n\".join(\n",
        "        [\n",
        "            f'<document><content>{doc.page_content}</content><source>{doc.metadata[\"source\"]}</source><page>{doc.metadata[\"page\"]+1}</page></document>'\n",
        "            for doc in docs\n",
        "        ]\n",
        "    )\n",
        "\n",
        "\n",
        "# Create RAG chain\n",
        "rag_chain = prompt | llm | StrOutputParser()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0fbc96e3",
      "metadata": {},
      "source": [
        "Now we generate the answer by passing the question to the `rag_chain` we created."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f8d16e04",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Pass questions to the RAG chain to generate answers\n",
        "generation = rag_chain.invoke({\"context\": format_docs(docs), \"question\": question})\n",
        "print(generation)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a0e9f601",
      "metadata": {},
      "source": [
        "### Added Hallucination checker for answers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "40ec0e97",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Define data model for hallucination check\n",
        "class GradeHallucinations(BaseModel):\n",
        "    \"\"\"Binary score for hallucination present in generation answer.\"\"\"\n",
        "\n",
        "    binary_score: str = Field(\n",
        "        description=\"Answer is grounded in the facts, 'yes' or 'no'\"\n",
        "    )\n",
        "\n",
        "\n",
        "# LLM initialization through function call\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_grader = llm.with_structured_output(GradeHallucinations)\n",
        "\n",
        "# Prompt settings\n",
        "system = \"\"\"You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \\n \n",
        "    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.\"\"\"\n",
        "\n",
        "# Create prompt template\n",
        "hallucination_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"Set of facts: \\n\\n {documents} \\n\\n LLM generation: {generation}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Create a hallucination evaluator\n",
        "hallucination_grader = hallucination_prompt | structured_llm_grader"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8550b7cf",
      "metadata": {},
      "source": [
        "Use the `hallucination_grader` you created to evaluate whether the generated answers are hallucinations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cb593684",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Use the evaluator to evaluate whether the generated answers are hallucinatory\n",
        "hallucination_grader.invoke({\"documents\": docs, \"generation\": generation})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "110eb9b0",
      "metadata": {},
      "outputs": [],
      "source": [
        "class GradeAnswer(BaseModel):\n",
        "    \"\"\"Binary scoring to evaluate the appropriateness of answers to questions\"\"\"\n",
        "\n",
        "    binary_score: str = Field(\n",
        "        description=\"Indicate 'yes' or 'no' whether the answer solves the question\"\n",
        "    )\n",
        "\n",
        "\n",
        "# LLM initialization through function call\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "structured_llm_grader = llm.with_structured_output(GradeAnswer)\n",
        "\n",
        "# Prompt settings\n",
        "system = \"\"\"You are a grader assessing whether an answer addresses / resolves a question \\n \n",
        "     Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.\"\"\"\n",
        "answer_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\"human\", \"User question: \\n\\n {question} \\n\\n LLM generation: {generation}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Create an answer evaluator by combining a prompt template and a structured LLM evaluator\n",
        "answer_grader = answer_prompt | structured_llm_grader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "66a26ad6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Use the evaluator to evaluate whether the generated answer solves the question\n",
        "answer_grader.invoke({\"question\": question, \"generation\": generation})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a9fc11dd",
      "metadata": {},
      "source": [
        "### Query Rewriter"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "id": "e9df325a",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "\n",
        "# Initialize LLM\n",
        "llm = ChatOpenAI(model=MODEL_NAME, temperature=0)\n",
        "\n",
        "# Definition of Query Rewriter prompt (can be freely modified)\n",
        "system = \"\"\"You a question re-writer that converts an input question to a better version that is optimized \\n \n",
        "for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\"\"\"\n",
        "\n",
        "# Create a Query Rewriter prompt template\n",
        "re_write_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\"system\", system),\n",
        "        (\n",
        "            \"human\",\n",
        "            \"Here is the initial question: \\n\\n {question} \\n Formulate an improved question.\",\n",
        "        ),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Create Query Rewriter\n",
        "question_rewriter = re_write_prompt | llm | StrOutputParser()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0abd3e83",
      "metadata": {},
      "source": [
        "Create an improved question by passing the question to the created `question_rewriter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c6eb92e7",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Generate improved questions by passing the questions to the question rewriter\n",
        "question_rewriter.invoke({\"question\": question})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d8d5ee42",
      "metadata": {},
      "source": [
        "##Tools\n",
        "\n",
        "### Web search tools\n",
        "\n",
        "The **Web Search Tool** is an important component of **Adaptive RAG** ​​and is used to retrieve up-to-date information. This tool helps users get quick and accurate answers to questions related to current events.\n",
        "\n",
        "- **Settings**: Set up your web search tools so they are ready to search for the latest information.\n",
        "- **Perform Search**: Search the web for relevant information based on your query.\n",
        "- **Result Analysis**: Analyzes search results to provide information most appropriate to the user's question."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "id": "e004263c",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_teddynote.tools.tavily import TavilySearch\n",
        "\n",
        "# Create a web search tool\n",
        "web_search_tool = TavilySearch(max_results=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "63d60abe",
      "metadata": {},
      "source": [
        "Run the web search tool and check the results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c13be8f3",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Call web search tool\n",
        "result = web_search_tool.search(\"Please tell me the Teddy Note Wikidocs LangChain tutorial URL\")\n",
        "print(result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1904c95c",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Check the first result of web search results\n",
        "result[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1ac37855",
      "metadata": {},
      "source": [
        "## Graph Construction"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "70ab91c2",
      "metadata": {},
      "source": [
        "### Defining graph states"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "id": "6d23ab6f",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import List\n",
        "from typing_extensions import TypedDict, Annotated\n",
        "\n",
        "\n",
        "# Define the state of the graph\n",
        "class GraphState(TypedDict):\n",
        "    \"\"\"\n",
        "    A data model representing the state of the graph\n",
        "\n",
        "    Attributes:\n",
        "        question: question\n",
        "        generation: LLM generated answers\n",
        "        documents: document list\n",
        "    \"\"\"\n",
        "\n",
        "    question: Annotated[str, \"User question\"]\n",
        "    generation: Annotated[str, \"LLM generated answer\"]\n",
        "    documents: Annotated[List[str], \"List of documents\"]"
      ]
    },
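    {
      "cell_type": "markdown",
      "id": "e81f0a6c",
      "metadata": {},
      "source": [
        "Each node returns a dictionary containing only the keys it changes, and LangGraph merges that dictionary into the graph state. For keys declared without a reducer, as in `GraphState`, the new value overwrites the old one.\n",
        "\n",
        "The following plain-Python sketch imitates that merge without LangGraph. `apply_update` and the sample values are hypothetical and serve only as an illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4c9b27d3",
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import Any, Dict\n",
        "\n",
        "\n",
        "# Hypothetical helper that imitates overwrite-style state updates\n",
        "def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:\n",
        "    # Keys in the update replace existing keys; all other keys are kept\n",
        "    return {**state, **update}\n",
        "\n",
        "\n",
        "sample_state = {\"question\": \"original question\", \"generation\": \"\", \"documents\": []}\n",
        "\n",
        "# A retrieval-style node returns only the documents key\n",
        "sample_state = apply_update(sample_state, {\"documents\": [\"doc A\", \"doc B\"]})\n",
        "\n",
        "# A rewrite-style node returns only the question key\n",
        "sample_state = apply_update(sample_state, {\"question\": \"rewritten question\"})\n",
        "\n",
        "print(sample_state[\"question\"])\n",
        "print(sample_state[\"documents\"])"
      ]
    },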
    {
      "cell_type": "markdown",
      "id": "f266cc42",
      "metadata": {},
      "source": [
        "## Define Graph Flows\n",
        "\n",
        "Clarify how **Adaptive RAG** ​​works by defining **Graph Flow**. This step establishes the states and transitions of the graph to increase the efficiency of query processing.\n",
        "\n",
        "- **State Definition**: Track the progress of a query by clearly defining each state in the graph.\n",
        "- **Set Transitions**: Set transitions between states to ensure queries follow the appropriate path.\n",
        "- **Flow Optimization**: Optimize the flow of the graph to improve the accuracy of information retrieval and creation."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "633bf00c",
      "metadata": {},
      "source": [
        "### Define Nodes\n",
        "\n",
        "Define the nodes to utilize.\n",
        "\n",
        "- `retrieve`: document retrieval node\n",
        "- `generate`: answer generation node\n",
        "- `grade_documents`: document relevance evaluation node\n",
        "- `transform_query`: question rewrite node\n",
        "- `web_search`: Web search node\n",
        "- `route_question`: question routing node\n",
        "- `decide_to_generate`: answer generation decision node\n",
        "- `hallucination_check`: hallucination evaluation node"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "id": "ee6f34d0",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_core.documents import Document\n",
        "\n",
        "\n",
        "# Document search node\n",
        "def retrieve(state):\n",
        "    print(\"==== [RETRIEVE] ====\")\n",
        "    question = state[\"question\"]\n",
        "\n",
        "    # Perform document search\n",
        "    documents = pdf_retriever.invoke(question)\n",
        "    return {\"documents\": documents}\n",
        "\n",
        "\n",
        "# Answer generation node\n",
        "def generate(state):\n",
        "    print(\"==== [GENERATE] ====\")\n",
        "    # Get questions and document search results\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "\n",
        "    # Generate RAG answer\n",
        "    generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
        "    return {\"generation\": generation}\n",
        "\n",
        "\n",
        "# Document relevance evaluation node\n",
        "def grade_documents(state):\n",
        "    print(\"==== [CHECK DOCUMENT RELEVANCE TO QUESTION] ====\")\n",
        "    # Get questions and document search results\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "\n",
        "    # Calculate relevance score for each document\n",
        "    filtered_docs = []\n",
        "    for d in documents:\n",
        "        score = retrieval_grader.invoke(\n",
        "            {\"question\": question, \"document\": d.page_content}\n",
        "        )\n",
        "        grade = score.binary_score\n",
        "        if grade == \"yes\":\n",
        "            print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
        "            # Add relevant documents\n",
        "            filtered_docs.append(d)\n",
        "        else:\n",
        "            # Skip irrelevant documents\n",
        "            print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
        "            continue\n",
        "    return {\"documents\": filtered_docs}\n",
        "\n",
        "\n",
        "# Question rewrite node\n",
        "def transform_query(state):\n",
        "    print(\"==== [TRANSFORM QUERY] ====\")\n",
        "    # Get questions and document search results\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "\n",
        "    # Rewrite the question\n",
        "    better_question = question_rewriter.invoke({\"question\": question})\n",
        "    return {\"question\": better_question}\n",
        "\n",
        "\n",
        "# web search node\n",
        "def web_search(state):\n",
        "    print(\"==== [WEB SEARCH] ====\")\n",
        "    # Get questions and document search results\n",
        "    question = state[\"question\"]\n",
        "\n",
        "    # Perform a web search\n",
        "    web_results = web_search_tool.invoke({\"query\": question})\n",
        "    web_results_docs = [\n",
        "        Document(\n",
        "            page_content=web_result[\"content\"],\n",
        "            metadata={\"source\": web_result[\"url\"]},\n",
        "        )\n",
        "        for web_result in web_results\n",
        "    ]\n",
        "\n",
        "    return {\"documents\": web_results_docs}\n",
        "\n",
        "\n",
        "# Question routing node\n",
        "def route_question(state):\n",
        "    print(\"==== [ROUTE QUESTION] ====\")\n",
        "    # Get questions\n",
        "    question = state[\"question\"]\n",
        "    # Question routing\n",
        "    source = question_router.invoke({\"question\": question})\n",
        "    # Node routing based on question routing results\n",
        "    if source.datasource == \"web_search\":\n",
        "        print(\"==== [ROUTE QUESTION TO WEB SEARCH] ====\")\n",
        "        return \"web_search\"\n",
        "    elif source.datasource == \"vectorstore\":\n",
        "        print(\"==== [ROUTE QUESTION TO VECTORSTORE] ====\")\n",
        "        return \"vectorstore\"\n",
        "\n",
        "\n",
        "# Document relevance evaluation node\n",
        "def decide_to_generate(state):\n",
        "    print(\"==== [DECISION TO GENERATE] ====\")\n",
        "    # Get document search results\n",
        "    filtered_documents = state[\"documents\"]\n",
        "\n",
        "    if not filtered_documents:\n",
        "        # Rewrite question if all documents are irrelevant\n",
        "        print(\n",
        "            \"==== [DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY] ====\"\n",
        "        )\n",
        "        return \"transform_query\"\n",
        "    else:\n",
        "        # Generate answer if relevant document exists\n",
        "        print(\"==== [DECISION: GENERATE] ====\")\n",
        "        return \"generate\"\n",
        "\n",
        "\n",
        "def hallucination_check(state):\n",
        "    print(\"==== [CHECK HALLUCINATIONS] ====\")\n",
        "    # Get questions and document search results\n",
        "    question = state[\"question\"]\n",
        "    documents = state[\"documents\"]\n",
        "    generation = state[\"generation\"]\n",
        "\n",
        "    # Hallucination Assessment\n",
        "    score = hallucination_grader.invoke(\n",
        "        {\"documents\": documents, \"generation\": generation}\n",
        "    )\n",
        "    grade = score.binary_score\n",
        "\n",
        "    # Check for hallucination\n",
        "    if grade == \"yes\":\n",
        "        print(\"==== [DECISION: GENERATION IS GROUNDED IN DOCUMENTS] ====\")\n",
        "\n",
        "        # Evaluate the relevance of the answer\n",
        "        print(\"==== [GRADE GENERATED ANSWER vs QUESTION] ====\")\n",
        "        score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
        "        grade = score.binary_score\n",
        "\n",
        "        # Processing according to relevance evaluation results\n",
        "        if grade == \"yes\":\n",
        "            print(\"==== [DECISION: GENERATED ANSWER ADDRESSES QUESTION] ====\")\n",
        "            return \"relevant\"\n",
        "        else:\n",
        "            print(\"==== [DECISION: GENERATED ANSWER DOES NOT ADDRESS QUESTION] ====\")\n",
        "            return \"not relevant\"\n",
        "    else:\n",
        "        print(\"==== [DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY] ====\")\n",
        "        return \"hallucination\""
      ]
    },
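    {
      "cell_type": "markdown",
      "id": "7b1e4c20",
      "metadata": {},
      "source": [
        "The routing and decision functions above return plain string labels, which the graph later maps to node names. As a minimal sketch of that pattern, the example below uses a hypothetical `FakeRoute` class in place of the structured output of `question_router`, and a hypothetical `route_label` helper. The fallback to `vectorstore` for unexpected labels is an assumption added for illustration, not part of the tutorial code.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class FakeRoute:\n",
        "    \"\"\"Hypothetical stand-in for the structured output of question_router.\"\"\"\n",
        "\n",
        "    datasource: str\n",
        "\n",
        "\n",
        "def route_label(source: FakeRoute, default: str = \"vectorstore\") -> str:\n",
        "    \"\"\"Map a router result to an edge label, falling back to a default.\"\"\"\n",
        "    if source.datasource in (\"web_search\", \"vectorstore\"):\n",
        "        return source.datasource\n",
        "    # Assumption: unknown labels fall back to the vector store\n",
        "    return default\n",
        "\n",
        "\n",
        "print(route_label(FakeRoute(\"web_search\")))\n",
        "print(route_label(FakeRoute(\"unknown\")))\n",
        "```\n",
        "\n",
        "Keeping the function's return value to a small, fixed set of labels makes it easy to check that every label has a matching entry when the conditional edges are defined."
      ]
    },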
    {
      "cell_type": "markdown",
      "id": "2412119d",
      "metadata": {},
      "source": [
        "## Graph Construction\n",
        "\n",
        "The **Graph Compile** step builds the workflow of **Adaptive RAG** ​​and makes it executable. This process connects each node and edge in the graph to define the overall flow of query processing.\n",
        "\n",
        "- **Node Definition**: Define each node to clarify the states and transitions of the graph.\n",
        "- **Set Edges**: Set edges between nodes to ensure that queries proceed along the appropriate path.\n",
        "- **Build workflow**: Build the entire flow of the graph to maximize the efficiency of information search and creation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "id": "c106a028",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langgraph.graph import END, StateGraph, START\n",
        "from langgraph.checkpoint.memory import MemorySaver\n",
        "\n",
        "# Initialize graph state\n",
        "workflow = StateGraph(GraphState)\n",
        "\n",
        "# Node definition\n",
        "workflow.add_node(\"web_search\", web_search) # Web search\n",
        "workflow.add_node(\"retrieve\", retrieve) # Retrieve document\n",
        "workflow.add_node(\"grade_documents\", grade_documents) # Evaluate documents\n",
        "workflow.add_node(\"generate\", generate) # Generate answer\n",
        "workflow.add_node(\"transform_query\", transform_query) # Transform query\n",
        "\n",
        "# Build graph\n",
        "workflow.add_conditional_edges(\n",
        "    START,\n",
        "    route_question,\n",
        "    {\n",
        "        \"web_search\": \"web_search\", # Route to web search\n",
        "        \"vectorstore\": \"retrieve\", # Routing to vectorstore\n",
        "    },\n",
        ")\n",
        "workflow.add_edge(\"web_search\", \"generate\") # Generate answer after web search\n",
        "workflow.add_edge(\"retrieve\", \"grade_documents\") # Evaluate documents after retrieval\n",
        "workflow.add_conditional_edges(\n",
        "    \"grade_documents\",\n",
        "    decide_to_generate,\n",
        "    {\n",
        "        \"transform_query\": \"transform_query\", # Query transformation required\n",
        "        \"generate\": \"generate\", # Can generate answers\n",
        "    },\n",
        ")\n",
        "workflow.add_edge(\"transform_query\", \"retrieve\") # Retrieve documents after transforming query\n",
        "workflow.add_conditional_edges(\n",
        "    \"generate\",\n",
        "    hallucination_check,\n",
        "    {\n",
        "        \"hallucination\": \"generate\", # Regenerate when hallucination occurs\n",
        "        \"relevant\": END, # Pass whether the answer is relevant\n",
        "        \"not relevant\": \"transform_query\", # Transform query if it fails to determine whether the answer is relevant\n",
        "    },\n",
        ")\n",
        "\n",
        "# Graph compilation\n",
        "app = workflow.compile(checkpointer=MemorySaver())"
      ]
    },
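    {
      "cell_type": "markdown",
      "id": "9d3a5f61",
      "metadata": {},
      "source": [
        "The mapping passed to `add_conditional_edges` can be read as a lookup table from the label returned by the decision function to the next node. The sketch below mirrors the mapping used for the `generate` node in plain Python. `GENERATE_EDGES` and `next_node` are hypothetical names, and the string `\"END\"` stands in for LangGraph's `END` constant. It only illustrates the lookup and does not use LangGraph itself.\n",
        "\n",
        "```python\n",
        "# Label -> next node, mirroring the mapping given for the \"generate\" node\n",
        "GENERATE_EDGES = {\n",
        "    \"hallucination\": \"generate\",\n",
        "    \"relevant\": \"END\",  # placeholder for LangGraph's END constant\n",
        "    \"not relevant\": \"transform_query\",\n",
        "}\n",
        "\n",
        "\n",
        "def next_node(label: str, edges: dict) -> str:\n",
        "    \"\"\"Return the node that follows a decision label.\"\"\"\n",
        "    if label not in edges:\n",
        "        # A label without a mapping entry has nowhere to go\n",
        "        raise KeyError(f\"No edge defined for label: {label!r}\")\n",
        "    return edges[label]\n",
        "\n",
        "\n",
        "for label in (\"hallucination\", \"relevant\", \"not relevant\"):\n",
        "    print(label, \"->\", next_node(label, GENERATE_EDGES))\n",
        "```\n",
        "\n",
        "Note that the `hallucination` label points back to `generate`, so this path forms a loop that continues until the answer is graded as grounded or the recursion limit is reached."
      ]
    },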
    {
      "cell_type": "markdown",
      "id": "748f4505",
      "metadata": {},
      "source": [
        "Visualize the graph."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "46ce79fe",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_teddynote.graphs import visualize_graph\n",
        "\n",
        "visualize_graph(app)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3fd2739b",
      "metadata": {},
      "source": [
        "## Execute Graph\n",
        "\n",
        "In the **Use Graph** step, the query processing results are checked through the execution of **Adaptive RAG**. This process processes queries along each node and edge of the graph to produce the final result.\n",
        "\n",
        "- **Graph Execution**: Executes the defined graph to follow the flow of the query.\n",
        "- **Check Results**: After running the graph, review the generated results to ensure that the query was processed properly.\n",
        "- **Result Analysis**: Analyze the generated results to evaluate whether they meet the purpose of the query."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b020b140",
      "metadata": {},
      "outputs": [],
      "source": [
        "from langchain_teddynote.messages import stream_graph, random_uuid\n",
        "from langchain_core.runnables import RunnableConfig\n",
        "\n",
        "# config settings (maximum number of recursions, thread_id)\n",
        "config = RunnableConfig(recursion_limit=20, configurable={\"thread_id\": random_uuid()})\n",
        "\n",
        "# Enter question\n",
        "inputs = {\n",
        "    \"question\": \"삼성전자가 개발한 생성형 AI 의 이름은?\",\n",
        "}\n",
        "\n",
        "# Run graph\n",
        "stream_graph(app, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
      ]
    },
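    {
      "cell_type": "markdown",
      "id": "c4e8a172",
      "metadata": {},
      "source": [
        "`RunnableConfig` is a typed dictionary, so an equivalent config can also be written as a plain `dict`. The sketch below assumes that `random_uuid` simply returns a UUID string (an assumption about that helper, not verified here) and uses the standard library `uuid` module instead. `make_config` is a hypothetical helper name.\n",
        "\n",
        "```python\n",
        "import uuid\n",
        "\n",
        "\n",
        "def make_config(recursion_limit: int = 20) -> dict:\n",
        "    \"\"\"Build a run config with a fresh thread_id (hypothetical helper).\"\"\"\n",
        "    return {\n",
        "        \"recursion_limit\": recursion_limit,\n",
        "        \"configurable\": {\"thread_id\": str(uuid.uuid4())},\n",
        "    }\n",
        "\n",
        "\n",
        "config_a = make_config()\n",
        "config_b = make_config()\n",
        "\n",
        "# Each call produces a different thread_id\n",
        "print(config_a[\"configurable\"][\"thread_id\"] != config_b[\"configurable\"][\"thread_id\"])\n",
        "```\n",
        "\n",
        "Creating a new config per question gives each run its own `thread_id`, so the checkpointer keeps the runs as separate threads."
      ]
    },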
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e25d23b6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Enter question\n",
        "inputs = {\n",
        "    \"question\": \"2024년 노벨 문학상 수상자는 누구인가요?\",\n",
        "}\n",
        "\n",
        "# Run graph\n",
        "stream_graph(app, inputs, config, [\"agent\", \"rewrite\", \"generate\"])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "langchain-kr-lwwSZlnu-py3.11",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
